#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK

# Copyright (C) 2020  GreenWaves Technologies, SAS

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import argparse
import importlib
import os
import sys
from pathlib import Path

import argcomplete

from nntool._version import __version__
from nntool.interpreter.commands.open_parser import add_open_options


def create_parser():
    """Build the argument parser for the ``nntool`` command line.

    Registers the positional graph/state file, the code-generation and
    scripting switches, the state-file overrides and the expression-kernel
    file names, then lets ``add_open_options`` register its own options
    on the same parser.

    Returns:
        argparse.ArgumentParser: the fully configured top-level parser.
    """
    # create the top-level parser
    parser = argparse.ArgumentParser(prog='nntool')

    # The graph file is optional (nargs='?'); it defaults to an empty string
    # when it is not given on the command line.
    parser.add_argument('graph_file',
                        help='graph file - TFLite file or JSON state file',
                        metavar="INPUT_GRAPH or STATE_FILE",
                        nargs=argparse.OPTIONAL,
                        default="")
    # NOTE(review): this mutually exclusive group only ever receives
    # --generate, so it does not currently exclude anything - confirm
    # whether another option was meant to be registered on it.
    exclus = parser.add_mutually_exclusive_group()
    exclus.add_argument('-g', '--generate',
                        help='generate code from a state file',
                        action='store_true')
    parser.add_argument('-c', '--template_file',
                        help='use this code template for generation',
                        metavar="TEMPLATE_FILE")
    parser.add_argument('-s', '--script_file',
                        help='run script and exit',
                        metavar="SCRIPT_FILE")
    parser.add_argument('-H', '--header_file',
                        help='write graph information to header file')
    parser.add_argument('-m', '--model_file',
                        help='override model file in state file')
    parser.add_argument('-M', '--model_directory',
                        help='override model directory in state file')
    parser.add_argument('-T', '--tensor_directory',
                        help='override tensor directory in state file')
    parser.add_argument('-d', '--dont_dump_tensors',
                        action='store_true',
                        help='don\'t dump tensors')
    parser.add_argument('--anonymise',
                        action='store_true',
                        help='try to make names anonymous')
    parser.add_argument('-l', '--log_level', choices=['debug', 'info', 'warning'],
                        help='set logging level', default=None)
    parser.add_argument('-v', '--version',
                        action='store_true',
                        help='show version number')
    parser.add_argument('--basic_kernel_header_file',
                        default='Expression_Kernels.h',
                        help='filename for basic kernel headers - defaults to Expression_Kernels.h')
    # Help text fixed: this option names the kernel *source* file, it
    # previously repeated the header option's description.
    parser.add_argument('--basic_kernel_source_file',
                        default='Expression_Kernels.c',
                        help='filename for basic kernel source - defaults to Expression_Kernels.c')
    add_open_options(parser)
    return parser


def main():
    """Entry point of the ``nntool`` command.

    Parses the command line and then does one of three things:

    * ``--version``: print the version string and return.
    * ``--generate``: hand the parsed arguments to the generator module's
      ``generate_code`` and return.
    * otherwise: make sure ``~/.nntool`` exists, create the shell and exit
      the process with the result of either running the script given with
      ``--script_file`` or of the interactive command loop.
    """
    cli_parser = create_parser()
    argcomplete.autocomplete(cli_parser)
    args = cli_parser.parse_args()

    if args.version:
        print("NNTOOL Version {}".format(__version__))
        return

    if args.generate:
        generator = importlib.import_module('nntool.interpreter.generator')
        generator.generate_code(args)
        return

    # Per-user working directory; it holds the shell's persistent history.
    workdir = os.path.join(str(Path.home()), '.nntool')
    os.makedirs(workdir, exist_ok=True, mode=0o755)

    # late import to speed up argcomplete
    shell_module = importlib.import_module('nntool.interpreter.nntool_shell')
    history_path = os.path.join(workdir, "history")
    shell = shell_module.NNToolShell(args,
                                     persistent_history_file=history_path,
                                     allow_cli_args=False)

    if args.script_file:
        exit_status = shell.run_script(args.script_file)
    else:
        exit_status = shell.cmdloop()
    sys.exit(exit_status)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
